{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT (Bidirectional Encoder Representations from Transformers)"
   ]
  },
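  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Transfer learning in NLP\n",
    "- Use a pretrained model to extract word or sentence features\n",
    "    - e.g. word2vec or a language model\n",
    "- The pretrained model is not updated\n",
    "- A new network must be built to capture the information the new task needs\n",
    "    - word2vec ignores word order (sequential information)\n",
    "    - A language model only looks in one direction"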
   ]
  },
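  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Motivation for BERT\n",
    "- A fine-tuning-based model for NLP\n",
    "<img src='bert0.png' align='middle'>\n",
    "- The pretrained model already extracts enough information\n",
    "- A new task only needs one simple output layer added on top\n"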
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BERT architecture\n",
    "- A Transformer with only the encoder\n",
    "- Two versions:\n",
    "    - Base: #blocks = 12, hidden size = 768, #heads = 12, #parameters = 110M\n",
    "    - Large: #blocks = 24, hidden size = 1024, #heads = 16, #parameters = 340M\n",
    "- Trained on a large corpus of more than 3B words (Wikipedia, books)\n"
   ]
  },
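  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check on the 110M figure for Base, the weight matrices can be counted by hand. This is only a sketch: the vocabulary size of 30,522, the maximum length of 512 and the feed-forward width of 3072 (4 × 768) are assumptions taken from the published BERT-Base configuration rather than from these notes, and biases, layer norms and output heads are ignored.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{embeddings} &= (30522 + 512 + 2) \\times 768 \\approx 23.8\\text{M} \\\\\\\\\n",
    "\\text{one block} &= \\underbrace{4 \\times 768^2}_{\\text{attention}} + \\underbrace{2 \\times 768 \\times 3072}_{\\text{feed-forward}} \\approx 7.1\\text{M} \\\\\\\\\n",
    "\\text{total} &\\approx 23.8\\text{M} + 12 \\times 7.1\\text{M} \\approx 109\\text{M}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The small terms left out here roughly account for the remaining gap to the quoted 110M."
   ]
  },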
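  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What is new in BERT\n",
    "\n",
    "(1) Changes to the input\n",
    "- Each example is a sentence pair\n",
    "    - `<cls>` (classification) marks the start, and a `<sep>` separator follows each of the two sentences\n",
    "- An extra segment embedding is added\n",
    "    - 0 for the first sentence, 1 for the second\n",
    "- The positional encoding is learnable\n",
    "\n",
    "<img src='bert1.png'>"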
   ]
  },
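  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three changes combine into a single input representation. As a sketch, write $t_i$ for the token id at position $i$ and $s_i \\in \\{0, 1\\}$ for its segment id (these symbols are introduced here only for illustration). The vector fed to the first encoder block is then the sum of three embedding lookups:\n",
    "\n",
    "$$\n",
    "\\mathbf{x}_i = \\mathbf{E}^{\\text{tok}}[t_i] + \\mathbf{E}^{\\text{seg}}[s_i] + \\mathbf{E}^{\\text{pos}}[i]\n",
    "$$\n",
    "\n",
    "All three tables are learned, and every row has the same dimension (the hidden size), so the sum is well defined."
   ]
  },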
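  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(2) Pretraining task 1: masked language modeling\n",
    "- The Transformer encoder is bidirectional, whereas a standard language model must be unidirectional\n",
    "- The masked language model randomly selects tokens (each with 15% probability) to be replaced with `<mask>`\n",
    "- Because `<mask>` never appears in fine-tuning tasks, a selected token is handled as follows:\n",
    "    - with 80% probability it is replaced with `<mask>`\n",
    "    - with 10% probability it is replaced with a random token\n",
    "    - with 10% probability it is left unchanged\n"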
   ]
  },
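  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the percentages concrete, here is the arithmetic for a single token, followed by a hypothetical 64-token input (the length 64 is chosen only for illustration):\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "P(\\text{shown as mask}) &= 0.15 \\times 0.8 = 0.12 \\\\\\\\\n",
    "P(\\text{shown as a random token}) &= 0.15 \\times 0.1 = 0.015 \\\\\\\\\n",
    "P(\\text{selected but left unchanged}) &= 0.15 \\times 0.1 = 0.015\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "For 64 tokens, about $0.15 \\times 64 = 9.6 \\approx 10$ positions are predicted, and on average $10 \\times 0.8 = 8$ of them appear as `<mask>` in the input."
   ]
  },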
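  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(3) Pretraining task 2: next sentence prediction\n",
    "- Predict whether the two sentences in a pair are adjacent\n",
    "- In the training examples:\n",
    "    - with 50% probability, pick an adjacent pair: `<cls> this movie is great <sep> i like it <sep>`\n",
    "    - with 50% probability, pick a random pair: `<cls> this movie is great <sep> hello world <sep>`\n",
    "- Feed the output at `<cls>` into a fully connected layer to make the prediction\n"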
   ]
  },
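  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Summary\n",
    "- BERT is designed for fine-tuning on NLP tasks\n",
    "- It modifies the Transformer encoder as follows\n",
    "    - Larger model, more training data\n",
    "    - Sentence-pair input, segment embeddings, learnable positional encoding\n",
    "    - Two pretraining tasks:\n",
    "        - masked language modeling\n",
    "        - next sentence prediction\n"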
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The BERT Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Input Representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tokens_and_segments(tokens_a, tokens_b=None):\n",
    "    \"\"\"Get tokens of the BERT input sequence and their segment IDs.\"\"\"\n",
    "    tokens = ['<cls>'] + tokens_a + ['<sep>']\n",
    "    segments = [0] * (len(tokens_a)+2)\n",
    "    if tokens_b is not None:\n",
    "        tokens += tokens_b + ['<sep>']\n",
    "        segments += [1] * (len(tokens_b) + 1)\n",
    "    return tokens, segments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BERT Encoder class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BERTEncoder(nn.Module):\n",
    "    \"\"\"BERT encoder.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,\n",
    "                 ffn_num_hiddens, num_heads, num_layers, dropout,\n",
    "                 max_len=1000, key_size=768, query_size=768, value_size=768,\n",
    "                 **kwargs):\n",
    "        super(BERTEncoder, self).__init__(**kwargs)\n",
    "        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)\n",
    "        self.segment_embedding = nn.Embedding(2, num_hiddens)  # New in BERT\n",
    "        self.blks = nn.Sequential()\n",
    "        for i in range(num_layers):  # Reuse the Transformer encoder block\n",
    "            self.blks.add_module(f\"{i}\", d2l.EncoderBlock(\n",
    "                key_size, query_size, value_size, num_hiddens, norm_shape,\n",
    "                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))\n",
    "        # Randomly initialized, learnable positional embeddings\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,\n",
    "                                                      num_hiddens))\n",
    "\n",
    "    def forward(self, tokens, segments, valid_lens):\n",
    "        X = self.token_embedding(tokens) + self.segment_embedding(segments)\n",
    "        # Index the Parameter itself (not `.data`) so it receives gradients\n",
    "        X = X + self.pos_embedding[:, :X.shape[1], :]\n",
    "        for blk in self.blks:\n",
    "            X = blk(X, valid_lens)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inference of BERT Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import multiprocessing\n",
    "import os\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4\n",
    "norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2\n",
    "encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,\n",
    "                      ffn_num_hiddens, num_heads, num_layers, dropout)\n",
    "\n",
    "tokens = torch.randint(0, vocab_size, (2, 8))\n",
    "segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])\n",
    "encoded_X = encoder(tokens, segments, None)\n",
    "encoded_X.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Masked Language Modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaskLM(nn.Module):\n",
    "    \"\"\"The masked language model task of BERT.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):\n",
    "        super(MaskLM, self).__init__(**kwargs)\n",
    "        self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),\n",
    "                                 nn.ReLU(),\n",
    "                                 nn.LayerNorm(num_hiddens),\n",
    "                                 nn.Linear(num_hiddens, vocab_size))\n",
    "\n",
    "    def forward(self, X, pred_positions):\n",
    "        num_pred_positions = pred_positions.shape[1]\n",
    "        pred_positions = pred_positions.reshape(-1)\n",
    "        batch_size = X.shape[0]\n",
    "        batch_idx = torch.arange(0, batch_size)\n",
    "        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)\n",
    "        masked_X = X[batch_idx, pred_positions]\n",
    "        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))\n",
    "        mlm_Y_hat = self.mlp(masked_X)\n",
    "        return mlm_Y_hat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forward inference of MaskLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlm = MaskLM(vocab_size, num_hiddens)\n",
    "mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])\n",
    "mlm_Y_hat = mlm(encoded_X, mlm_positions)\n",
    "mlm_Y_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])\n",
    "loss = nn.CrossEntropyLoss(reduction='none')\n",
    "mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))\n",
    "mlm_l.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next Sentence Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NextSentencePred(nn.Module):\n",
    "    \"\"\"The next sentence prediction task of BERT.\"\"\"\n",
    "    def __init__(self, num_inputs, **kwargs):\n",
    "        super(NextSentencePred, self).__init__(**kwargs)\n",
    "        self.output = nn.Linear(num_inputs, 2)\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.output(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forward inference of NextSentencePred"
   ]
  },
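  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten to (batch size, num_steps * num_hiddens) for this standalone demo;\n",
    "# BERTModel below feeds only the '<cls>' representation to NextSentencePred\n",
    "encoded_X = torch.flatten(encoded_X, start_dim=1)\n",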
    "nsp = NextSentencePred(encoded_X.shape[-1])\n",
    "nsp_Y_hat = nsp(encoded_X)\n",
    "nsp_Y_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nsp_y = torch.tensor([0, 1])\n",
    "nsp_l = loss(nsp_Y_hat, nsp_y)\n",
    "nsp_l.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting All Things Together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BERTModel(nn.Module):\n",
    "    \"\"\"The BERT model.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,\n",
    "                 ffn_num_hiddens, num_heads, num_layers, dropout,\n",
    "                 max_len=1000, key_size=768, query_size=768, value_size=768,\n",
    "                 hid_in_features=768, mlm_in_features=768,\n",
    "                 nsp_in_features=768):\n",
    "        super(BERTModel, self).__init__()\n",
    "        self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,\n",
    "                    ffn_num_input, ffn_num_hiddens, num_heads, num_layers,\n",
    "                    dropout, max_len=max_len, key_size=key_size,\n",
    "                    query_size=query_size, value_size=value_size)\n",
    "        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),\n",
    "                                    nn.Tanh())\n",
    "        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)\n",
    "        self.nsp = NextSentencePred(nsp_in_features)\n",
    "\n",
    "    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):\n",
    "        encoded_X = self.encoder(tokens, segments, valid_lens)\n",
    "        if pred_positions is not None:\n",
    "            mlm_Y_hat = self.mlm(encoded_X, pred_positions)\n",
    "        else:\n",
    "            mlm_Y_hat = None\n",
    "        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))\n",
    "        return encoded_X, mlm_Y_hat, nsp_Y_hat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Dataset for Pretraining BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import torch\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The WikiText-2 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d2l.DATA_HUB['wikitext-2'] = (\n",
    "    'https://s3.amazonaws.com/research.metamind.io/wikitext/'\n",
    "    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')\n",
    "\n",
    "def _read_wiki(data_dir):\n",
    "    file_name = os.path.join(data_dir, 'wiki.train.tokens')\n",
    "    with open(file_name, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    paragraphs = [line.strip().lower().split(' . ')\n",
    "                  for line in lines if len(line.split(' . ')) >= 2]\n",
    "    random.shuffle(paragraphs)\n",
    "    return paragraphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generating the Next Sentence Prediction Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _get_next_sentence(sentence, next_sentence, paragraphs):\n",
    "    if random.random() < 0.5:  # 50% of the time, keep the true next sentence\n",
    "        is_next = True\n",
    "    else:\n",
    "        # 50% of the time, draw a random sentence from a random paragraph\n",
    "        next_sentence = random.choice(random.choice(paragraphs))\n",
    "        is_next = False\n",
    "    return sentence, next_sentence, is_next\n",
    "\n",
    "def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):\n",
    "    nsp_data_from_paragraph = []\n",
    "    for i in range(len(paragraph) - 1):\n",
    "        tokens_a, tokens_b, is_next = _get_next_sentence(\n",
    "            paragraph[i], paragraph[i + 1], paragraphs)  # is_next is the NSP label\n",
    "        if len(tokens_a) + len(tokens_b) + 3 > max_len:\n",
    "            continue\n",
    "        tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n",
    "        nsp_data_from_paragraph.append((tokens, segments, is_next))\n",
    "    return nsp_data_from_paragraph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generating the Masked Language Modeling Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,\n",
    "                        vocab):\n",
    "    # Copy the tokens; the copy may receive '<mask>' or random replacements\n",
    "    mlm_input_tokens = [token for token in tokens]\n",
    "    pred_positions_and_labels = []\n",
    "    # Shuffle so that a random subset of positions is picked for prediction\n",
    "    random.shuffle(candidate_pred_positions)\n",
    "    for mlm_pred_position in candidate_pred_positions:\n",
    "        if len(pred_positions_and_labels) >= num_mlm_preds:\n",
    "            break\n",
    "        masked_token = None\n",
    "        # 80% of the time: replace the token with '<mask>'\n",
    "        if random.random() < 0.8:\n",
    "            masked_token = '<mask>'\n",
    "        else:\n",
    "            # 10% of the time: keep the token unchanged\n",
    "            if random.random() < 0.5:\n",
    "                masked_token = tokens[mlm_pred_position]\n",
    "            # 10% of the time: replace it with a random token string\n",
    "            else:\n",
    "                masked_token = random.choice(vocab.idx_to_token)\n",
    "        mlm_input_tokens[mlm_pred_position] = masked_token\n",
    "        pred_positions_and_labels.append(\n",
    "            (mlm_pred_position, tokens[mlm_pred_position]))\n",
    "    return mlm_input_tokens, pred_positions_and_labels\n",
    "\n",
    "def _get_mlm_data_from_tokens(tokens, vocab):\n",
    "    candidate_pred_positions = []\n",
    "    # tokens is a list of strings\n",
    "    for i, token in enumerate(tokens):\n",
    "        # Special tokens are not predicted in the masked language modeling task\n",
    "        if token in ['<cls>', '<sep>']:\n",
    "            continue\n",
    "        candidate_pred_positions.append(i)\n",
    "    # Predict 15% of the tokens, chosen at random\n",
    "    num_mlm_preds = max(1, round(len(tokens) * 0.15))\n",
    "    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(\n",
    "        tokens, candidate_pred_positions, num_mlm_preds, vocab)\n",
    "    pred_positions_and_labels = sorted(pred_positions_and_labels,\n",
    "                                       key=lambda x: x[0])\n",
    "    pred_positions = [v[0] for v in pred_positions_and_labels]\n",
    "    mlm_pred_labels = [v[1] for v in pred_positions_and_labels]\n",
    "    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Append the special `<pad>` tokens to the inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pad_bert_inputs(examples, max_len, vocab):\n",
    "    # If a concatenated sentence pair is shorter than max_len, pad the rest\n",
    "    max_num_mlm_preds = round(max_len * 0.15)\n",
    "    all_token_ids, all_segments, valid_lens = [], [], []\n",
    "    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []\n",
    "    nsp_labels = []\n",
    "    for (token_ids, pred_positions, mlm_pred_label_ids, segments,\n",
    "         is_next) in examples:\n",
    "        all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (\n",
    "            max_len - len(token_ids)), dtype=torch.long))\n",
    "        all_segments.append(torch.tensor(segments + [0] * (\n",
    "            max_len - len(segments)), dtype=torch.long))\n",
    "        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))\n",
    "        all_pred_positions.append(torch.tensor(pred_positions + [0] * (\n",
    "            max_num_mlm_preds - len(pred_positions)), dtype=torch.long))\n",
    "        all_mlm_weights.append(\n",
    "            torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (\n",
    "                max_num_mlm_preds - len(pred_positions)),\n",
    "                dtype=torch.float32))\n",
    "        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (\n",
    "            max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))\n",
    "        nsp_labels.append(torch.tensor(is_next, dtype=torch.long))\n",
    "    return (all_token_ids, all_segments, valid_lens, all_pred_positions,\n",
    "            all_mlm_weights, all_mlm_labels, nsp_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The WikiText-2 dataset for pretraining BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _WikiTextDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, paragraphs, max_len):\n",
    "        paragraphs = [d2l.tokenize(\n",
    "            paragraph, token='word') for paragraph in paragraphs]\n",
    "        sentences = [sentence for paragraph in paragraphs\n",
    "                     for sentence in paragraph]\n",
    "        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[\n",
    "            '<pad>', '<mask>', '<cls>', '<sep>'])\n",
    "        examples = []\n",
    "        for paragraph in paragraphs:\n",
    "            examples.extend(_get_nsp_data_from_paragraph(\n",
    "                paragraph, paragraphs, self.vocab, max_len))\n",
    "        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)\n",
    "                      + (segments, is_next))\n",
    "                     for tokens, segments, is_next in examples]\n",
    "        (self.all_token_ids, self.all_segments, self.valid_lens,\n",
    "         self.all_pred_positions, self.all_mlm_weights,\n",
    "         self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(\n",
    "            examples, max_len, self.vocab)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return (self.all_token_ids[idx], self.all_segments[idx],\n",
    "                self.valid_lens[idx], self.all_pred_positions[idx],\n",
    "                self.all_mlm_weights[idx], self.all_mlm_labels[idx],\n",
    "                self.nsp_labels[idx])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.all_token_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the WikiText-2 dataset and generate pretraining examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data_wiki(batch_size, max_len):\n",
    "    \"\"\"Load the WikiText-2 dataset.\"\"\"\n",
    "    num_workers = d2l.get_dataloader_workers()\n",
    "    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')\n",
    "    paragraphs = _read_wiki(data_dir)\n",
    "    train_set = _WikiTextDataset(paragraphs, max_len)\n",
    "    train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n",
    "                                        shuffle=True, num_workers=num_workers)\n",
    "    return train_iter, train_set.vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Print out the shapes of a minibatch of BERT pretraining examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size, max_len = 512, 64\n",
    "train_iter, vocab = load_data_wiki(batch_size, max_len)\n",
    "\n",
    "for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,\n",
    "     mlm_Y, nsp_y) in train_iter:\n",
    "    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,\n",
    "          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,\n",
    "          nsp_y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pretraining BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "\n",
    "batch_size, max_len = 512, 64\n",
    "train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small BERT, using 2 layers, 128 hidden units, and 2 self-attention heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = d2l.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],\n",
    "                    ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,\n",
    "                    num_layers=2, dropout=0.2, key_size=128, query_size=128,\n",
    "                    value_size=128, hid_in_features=128, mlm_in_features=128,\n",
    "                    nsp_in_features=128)\n",
    "devices = d2l.try_all_gpus()\n",
    "loss = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Computes the loss for both the masked language modeling and next sentence prediction tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,\n",
    "                         segments_X, valid_lens_x,\n",
    "                         pred_positions_X, mlm_weights_X,\n",
    "                         mlm_Y, nsp_y):\n",
    "    # Forward pass\n",
    "    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,\n",
    "                                  valid_lens_x.reshape(-1),\n",
    "                                  pred_positions_X)\n",
    "    # Per-position MLM loss (reduction='none'), so that the 0/1 weights\n",
    "    # actually exclude padded prediction slots from the average\n",
    "    mlm_l = nn.functional.cross_entropy(\n",
    "        mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1),\n",
    "        reduction='none') * mlm_weights_X.reshape(-1)\n",
    "    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)\n",
    "    # NSP loss, using the mean-reduced `loss` passed in\n",
    "    nsp_l = loss(nsp_Y_hat, nsp_y)\n",
    "    l = mlm_l + nsp_l\n",
    "    return mlm_l, nsp_l, l"
   ]
  },
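  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In formula form, the intended combined objective can be sketched as follows. The symbols are introduced here only for illustration: $w_i \\in \\{0, 1\\}$ is the entry of `mlm_weights_X` for prediction slot $i$ (0 for padding), and $\\mathrm{CE}$ is the cross-entropy between the predicted distribution $\\hat{y}_i$ and the label $y_i$.\n",
    "\n",
    "$$\n",
    "\\ell_{\\text{mlm}} = \\frac{\\sum_i w_i \\, \\mathrm{CE}(\\hat{y}_i, y_i)}{\\sum_i w_i + 10^{-8}}, \\qquad \\ell = \\ell_{\\text{mlm}} + \\ell_{\\text{nsp}}\n",
    "$$\n",
    "\n",
    "The weights keep padded prediction slots out of the average, and the small constant guards against division by zero."
   ]
  },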
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pretrain BERT (net) on the WikiText-2 (train_iter) dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):\n",
    "    net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
    "    trainer = torch.optim.Adam(net.parameters(), lr=0.01)\n",
    "    step, timer = 0, d2l.Timer()\n",
    "    animator = d2l.Animator(xlabel='step', ylabel='loss',\n",
    "                            xlim=[1, num_steps], legend=['mlm', 'nsp'])\n",
    "    # Sum of MLM losses, sum of NSP losses, number of sentence pairs, count\n",
    "    metric = d2l.Accumulator(4)\n",
    "    num_steps_reached = False\n",
    "    while step < num_steps and not num_steps_reached:\n",
    "        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\\\n",
    "            mlm_weights_X, mlm_Y, nsp_y in train_iter:\n",
    "            tokens_X = tokens_X.to(devices[0])\n",
    "            segments_X = segments_X.to(devices[0])\n",
    "            valid_lens_x = valid_lens_x.to(devices[0])\n",
    "            pred_positions_X = pred_positions_X.to(devices[0])\n",
    "            mlm_weights_X = mlm_weights_X.to(devices[0])\n",
    "            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])\n",
    "            trainer.zero_grad()\n",
    "            timer.start()\n",
    "            mlm_l, nsp_l, l = _get_batch_loss_bert(\n",
    "                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,\n",
    "                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)\n",
    "            l.backward()\n",
    "            trainer.step()\n",
    "            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)\n",
    "            timer.stop()\n",
    "            animator.add(step + 1,\n",
    "                         (metric[0] / metric[3], metric[1] / metric[3]))\n",
    "            step += 1\n",
    "            if step == num_steps:\n",
    "                num_steps_reached = True\n",
    "                break\n",
    "\n",
    "    print(f'MLM loss {metric[0] / metric[3]:.3f}, '\n",
    "          f'NSP loss {metric[1] / metric[3]:.3f}')\n",
    "    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '\n",
    "          f'{str(devices)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Representing Text with BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_bert_encoding(net, tokens_a, tokens_b=None):\n",
    "    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n",
    "    token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)\n",
    "    segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)\n",
    "    valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)\n",
    "    encoded_X, _, _ = net(token_ids, segments, valid_len)\n",
    "    return encoded_X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Consider the sentence \"a crane is flying\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens_a = ['a', 'crane', 'is', 'flying']\n",
    "encoded_text = get_bert_encoding(net, tokens_a)\n",
    "encoded_text_cls = encoded_text[:, 0, :]\n",
    "encoded_text_crane = encoded_text[:, 2, :]\n",
    "encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now consider a sentence pair \"a crane driver came\" and \"he just left\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']\n",
    "encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)\n",
    "encoded_pair_cls = encoded_pair[:, 0, :]\n",
    "encoded_pair_crane = encoded_pair[:, 2, :]\n",
    "encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": true,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
